{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "g4jtzXwEvW2-"
   },
   "outputs": [],
   "source": [
    "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
    "\n",
    "# Licensed to the Apache Software Foundation (ASF) under one\n",
    "# or more contributor license agreements. See the NOTICE file\n",
    "# distributed with this work for additional information\n",
    "# regarding copyright ownership. The ASF licenses this file\n",
    "# to you under the Apache License, Version 2.0 (the\n",
    "# \"License\"); you may not use this file except in compliance\n",
    "# with the License. You may obtain a copy of the License at\n",
    "#\n",
    "#   http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing,\n",
    "# software distributed under the License is distributed on an\n",
    "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
    "# KIND, either express or implied. See the License for the\n",
    "# specific language governing permissions and limitations\n",
    "# under the License."
   ],
   "id": "g4jtzXwEvW2-"
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HtysPAVSvcMg"
   },
   "source": [
    "# 🌦️ Weather forecasting -- _Training_\n",
    "\n",
    "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/GoogleCloudPlatform/python-docs-samples/blob/main/people-and-planet-ai/weather-forecasting/notebooks/3-training.ipynb)\n",
    "\n",
    "This sample is broken into the following notebooks:\n",
    "\n",
    "* [![Open in Colab](https://github.com/googlecolab/open_in_colab/raw/main/images/icon16.png) **🧭 Overview**](https://colab.research.google.com/github/GoogleCloudPlatform/python-docs-samples/blob/main/people-and-planet-ai/weather-forecasting/notebooks/1-overview.ipynb):\n",
    "  Go through what we want to achieve, and explore the data we want to use as _inputs and outputs_ for our model.\n",
    "\n",
    "* [![Open in Colab](https://github.com/googlecolab/open_in_colab/raw/main/images/icon16.png) **🗄️ Create the dataset**](https://colab.research.google.com/github/GoogleCloudPlatform/python-docs-samples/blob/main/people-and-planet-ai/weather-forecasting/notebooks/2-dataset.ipynb):\n",
    "  Use [Apache Beam](https://beam.apache.org/) to fetch data from [Earth Engine](https://earthengine.google.com/) in parallel, and create a dataset for our model in [Dataflow](https://cloud.google.com/dataflow).\n",
    "\n",
    "* ![Open in Colab](https://github.com/googlecolab/open_in_colab/raw/main/images/icon16.png) **🧠 Train the model**:\n",
    "  Build a simple _Fully Convolutional Network_ in [PyTorch](https://pytorch.org/) and train it in [Vertex AI](https://cloud.google.com/vertex-ai/docs/training/custom-training) with the dataset we created.\n",
    "\n",
    "* [![Open in Colab](https://github.com/googlecolab/open_in_colab/raw/main/images/icon16.png) **🔮 Model predictions**](https://colab.research.google.com/github/GoogleCloudPlatform/python-docs-samples/blob/main/people-and-planet-ai/weather-forecasting/notebooks/4-predictions.ipynb):\n",
    "  Get predictions from the model with data it has never seen before.\n",
    "\n",
    "This sample leverages geospatial satellite and precipitation data from [Google Earth Engine](https://earthengine.google.com/).\n",
    "Using satellite imagery, you'll build and train a model for rain \"nowcasting\", i.e., predicting the amount of rainfall for a given geospatial region and time in the immediate future.\n",
    "\n",
    "* ⏲️ **Time estimate**: ~40 minutes\n",
    "* 💰 **Cost estimate**: [a few cents on Vertex AI](https://cloud.google.com/vertex-ai/pricing#custom-trained_models)\n",
    "\n",
    "💚 This is one of many **machine learning how-to samples** inspired by **real climate solutions** aired on the [People and Planet AI 🎥 series](https://www.youtube.com/playlist?list=PLIivdWyY5sqI-llB35Dcb187ZG155Rs_7)."
   ],
   "id": "HtysPAVSvcMg"
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RuFZck60B8t-"
   },
   "source": [
    "# 🎬 Before you begin\n",
    "\n",
    "Let's start by cloning the GitHub repository and installing some dependencies."
   ],
   "id": "RuFZck60B8t-"
  },
  {
   "cell_type": "code",
   "source": [
    "# Now let's get the code from GitHub and navigate to the sample.\n",
    "!git clone https://github.com/GoogleCloudPlatform/python-docs-samples.git\n",
    "%cd python-docs-samples/people-and-planet-ai/weather-forecasting"
   ],
   "metadata": {
    "id": "W-fPxkYD9FaP"
   },
   "execution_count": null,
   "outputs": [],
   "id": "W-fPxkYD9FaP"
  },
  {
   "cell_type": "markdown",
   "source": [
    "The [`weather-model`](../serving/weather-model) local package contains the model definition and the training script.\n",
    "This ensures we use the same model definition for both training and predictions.\n"
   ],
   "metadata": {
    "id": "r5OijZcuInAe"
   },
   "id": "r5OijZcuInAe"
  },
  {
   "cell_type": "code",
   "source": [
    "# Upgrade `setuptools` to install packages from pyproject.toml files.\n",
    "!pip install --quiet --upgrade --no-warn-conflicts pip setuptools\n",
    "\n",
    "# We need `build` and `virtualenv` to build the local packages.\n",
    "!pip install --quiet build virtualenv\n",
    "\n",
    "# Install the `weather-model` local package.\n",
    "!pip install google-cloud-aiplatform serving/weather-model"
   ],
   "metadata": {
    "id": "AlcsK6pd-x0I"
   },
   "execution_count": null,
   "outputs": [],
   "id": "AlcsK6pd-x0I"
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G75Y6HszxBL8"
   },
   "source": [
    "> **🛑 Restart the runtime 🛑**\n",
    "\n",
    "Colab already comes with many dependencies pre-loaded.\n",
    "In order to ensure everything runs as expected, we **_must_ restart the runtime**. This allows Colab to load the latest versions of the libraries.\n",
    "\n",
    "![\"Runtime\" > \"Restart runtime\"](images/restart-runtime.png)"
   ],
   "id": "G75Y6HszxBL8"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xGXRHJ9TFs24"
   },
   "outputs": [],
   "source": [
    "# Alternatively, restart the runtime by ending the process.\n",
    "exit()"
   ],
   "id": "xGXRHJ9TFs24"
  },
  {
   "cell_type": "markdown",
   "source": [
    "After restarting the runtime, let's navigate back into the sample directory."
   ],
   "metadata": {
    "id": "WI_vvBpPD4tr"
   },
   "id": "WI_vvBpPD4tr"
  },
  {
   "cell_type": "code",
   "source": [
    "%cd python-docs-samples/people-and-planet-ai/weather-forecasting"
   ],
   "metadata": {
    "id": "6fdyXMdlD3cz",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "457ec966-1f2a-4e76-df1b-62d7d9d77e60"
   },
   "execution_count": null,
   "outputs": [],
   "id": "6fdyXMdlD3cz"
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mHvEEW6oyFGV"
   },
   "source": [
    "## ☁️ My Google Cloud resources\n",
    "\n",
    "Make sure you have followed these steps to configure your Google Cloud project:\n",
    "\n",
    "1. Enable the APIs: _Vertex AI_\n",
    "\n",
    "  <button>\n",
    "\n",
    "  [Click here to enable the APIs](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com)\n",
    "  </button>\n",
    "\n",
    "1. Create or use an existing Cloud Storage bucket.\n",
    "\n",
    "  <button>\n",
    "\n",
    "  [Click here to create a new Cloud Storage bucket](https://console.cloud.google.com/storage/create-bucket)\n",
    "  </button>\n",
    "\n",
    "Once you have everything ready, you can go ahead and fill in your Google Cloud resources in the following code cell.\n",
    "Make sure you run it!"
   ],
   "id": "mHvEEW6oyFGV"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "YMPNUR0pyRvy"
   },
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import os\n",
    "from google.colab import auth\n",
    "\n",
    "# Please fill in these values.\n",
    "project = \"\"  # @param {type:\"string\"}\n",
    "bucket = \"\"  # @param {type:\"string\"}\n",
    "location = \"us-central1\"  # @param {type:\"string\"}\n",
    "\n",
    "# Quick input validations.\n",
    "assert project, \"⚠️ Please provide a Google Cloud project ID\"\n",
    "assert bucket, \"⚠️ Please provide a Cloud Storage bucket name\"\n",
    "assert not bucket.startswith(\n",
    "    \"gs://\"\n",
    "), f\"⚠️ Please remove the gs:// prefix from the bucket name: {bucket}\"\n",
    "assert location, \"⚠️ Please provide a Google Cloud location\"\n",
    "\n",
    "# Authenticate to Colab.\n",
    "auth.authenticate_user()\n",
    "\n",
    "# Set GOOGLE_CLOUD_PROJECT for google.auth.default().\n",
    "os.environ[\"GOOGLE_CLOUD_PROJECT\"] = project\n",
    "\n",
    "# Set the gcloud project for other gcloud commands.\n",
    "!gcloud config set project {project}"
   ],
   "id": "YMPNUR0pyRvy"
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "02b1b9dd"
   },
   "source": [
    "# 🧠 Train the model locally\n",
    "\n",
    "We need the same model definition for both training and prediction.\n",
    "That's why we created the local [`weather-model`](../serving/weather-model) module.\n",
    "It contains [`weather/model.py`](../serving/weather-model/weather/model.py) where the model is defined, and [`weather/trainer.py`](../serving/weather-model/weather/trainer.py) where all the training code lives."
   ],
   "id": "02b1b9dd"
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PY5H3OMjfVAR"
   },
   "source": [
    "## 📖 Read the dataset\n",
    "\n",
    "Unfortunately, PyTorch cannot read files from Cloud Storage out of the box.\n",
    "Fortunately, Vertex AI uses [Cloud Storage FUSE](https://cloud.google.com/blog/products/ai-machine-learning/cloud-storage-file-system-ai-training) to mount and access Cloud Storage files as if they were local files.\n",
    "\n",
    "For now, let's download the data files we created in the [🗄️ **Create the dataset**](https://colab.research.google.com/github/GoogleCloudPlatform/python-docs-samples/blob/main/people-and-planet-ai/weather-forecasting/notebooks/2-dataset.ipynb) notebook so we can work with them locally."
   ],
   "id": "PY5H3OMjfVAR"
  },
  {
   "cell_type": "code",
   "source": [
    "data_path_gcs = f\"gs://{bucket}/weather/data\"\n",
    "\n",
    "!mkdir -p data-training\n",
    "!gsutil -m cp {data_path_gcs}/* data-training"
   ],
   "metadata": {
    "id": "h_IUpnqvO-sa"
   },
   "id": "h_IUpnqvO-sa",
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "First, we need to load the dataset to feed it to the model.\n",
    "To read a dataset in PyTorch, we could manually subclass `torch.utils.data.Dataset`, but instead we're going to use [Hugging Face 🤗 Datasets](https://huggingface.co/docs/datasets/main/en/index), which provides a high-level interface that makes datasets easier to work with.\n",
    "\n",
    "Our data files are compressed NumPy files, which we can easily load with NumPy.\n",
    "To load them into a 🤗 Dataset, we can use [`Dataset.from_dict`](https://huggingface.co/docs/datasets/main/en/loading#python-dictionary) and pass it a dictionary containing all the file names of our data files.\n",
    "Then, we use [`Dataset.map`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.map) to read the data files and process the examples in parallel.\n",
    "Additionally, we _augment_ the data by rotating and flipping each example.\n",
    "To split our dataset into training and testing/validation subsets, we use [`Dataset.train_test_split`](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.train_test_split).\n",
    "\n",
    "In [`weather/trainer.py`](../serving/weather-model/weather/trainer.py) we defined the `read_dataset` function, which loads our data files and returns a 🤗 `DatasetDict` with train/test splits."
   ],
   "metadata": {
    "id": "Pl3qbyggO7rR"
   },
   "id": "Pl3qbyggO7rR"
  },
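  {
   "cell_type": "markdown",
   "metadata": {
    "id": "augmentSketch01"
   },
   "source": [
    "The rotate-and-flip augmentation described above can be sketched with NumPy alone.\n",
    "This is a minimal sketch, not the code in [`weather/trainer.py`](../serving/weather-model/weather/trainer.py): it assumes each example is a channels-last `(width, height, channels)` array, and the `augment` helper is hypothetical.\n",
    "It produces the 8 combinations of 90-degree rotations with and without a flip, applying the same transform to the inputs and the labels so every label pixel stays aligned with its input pixel.\n",
    "\n",
    "```python\n",
    "from __future__ import annotations\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def augment(inputs: np.ndarray, labels: np.ndarray) -> list[tuple[np.ndarray, np.ndarray]]:\n",
    "    # Hypothetical helper: returns 8 (inputs, labels) pairs, one for each\n",
    "    # rotation of 0, 90, 180 and 270 degrees, with and without a flip.\n",
    "    pairs = []\n",
    "    for k in range(4):\n",
    "        # Rotate over the (width, height) axes only, never the channels axis.\n",
    "        rotated_inputs = np.rot90(inputs, k, axes=(0, 1))\n",
    "        rotated_labels = np.rot90(labels, k, axes=(0, 1))\n",
    "        pairs.append((rotated_inputs, rotated_labels))\n",
    "        # Flip along the width axis to add the mirrored version.\n",
    "        pairs.append((np.flip(rotated_inputs, axis=0), np.flip(rotated_labels, axis=0)))\n",
    "    return pairs\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "inputs = rng.random((5, 5, 52), dtype=np.float32)\n",
    "labels = rng.random((5, 5, 2), dtype=np.float32)\n",
    "augmented = augment(inputs, labels)\n",
    "print(len(augmented), augmented[1][0].shape, augmented[1][1].shape)\n",
    "```\n",
    "\n",
    "Running it prints `8 (5, 5, 52) (5, 5, 2)`. Every variant keeps its shape because the examples are square."
   ],
   "id": "augmentSketch01"
  },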
  {
   "cell_type": "code",
   "source": [
    "from weather.trainer import read_dataset\n",
    "\n",
    "data_path = \"data-training\"\n",
    "train_test_ratio = 0.9  # 90% train, 10% test\n",
    "\n",
    "# Read the dataset with train/test splits.\n",
    "dataset = read_dataset(data_path, train_test_ratio)"
   ],
   "metadata": {
    "id": "rxwvw7ihacXy"
   },
   "id": "rxwvw7ihacXy",
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "print(dataset)"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ItTBWR98dByh",
    "outputId": "8c522562-9937-4dc5-e482-8bec3cdba277"
   },
   "id": "ItTBWR98dByh",
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['inputs', 'labels'],\n",
      "        num_rows: 3069\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['inputs', 'labels'],\n",
      "        num_rows: 341\n",
      "    })\n",
      "})\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "> 💡 For more information on loading data into a 🤗 Dataset, refer to the [Loading data](https://huggingface.co/docs/datasets/main/en/loading) guide.\n",
    "\n",
    "🤗 Datasets allow for random access just like PyTorch Datasets.\n",
    "\n",
    "Let's see the shapes of the first training example from the `train` split.\n",
    "When we access an example, we get an `{'inputs': list, 'labels': list}` dictionary, where each value is a [Python list](https://docs.python.org/3/library/stdtypes.html#list).\n",
    "We can then convert them into [PyTorch tensors](https://pytorch.org/docs/stable/tensors.html) for further use."
   ],
   "metadata": {
    "id": "jnlg80Tl4QLS"
   },
   "id": "jnlg80Tl4QLS"
  },
  {
   "cell_type": "code",
   "source": [
    "import torch\n",
    "\n",
    "train_dataset = dataset[\"train\"]\n",
    "example = train_dataset[0]  # random access the first element\n",
    "\n",
    "print(f\"inputs: {torch.as_tensor(example['inputs']).shape}\")\n",
    "print(f\"labels: {torch.as_tensor(example['labels']).shape}\")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Ji67MIKQ58Zr",
    "outputId": "7e98452a-2744-4a72-93e8-9a53e1f6b695"
   },
   "id": "Ji67MIKQ58Zr",
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "inputs: torch.Size([5, 5, 52])\n",
      "labels: torch.Size([5, 5, 2])\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "The _inputs_ have the shape `(width, height, num_inputs)`, where each input is the value of an Earth Engine band.\n",
    "\n",
    "The _outputs_ have the shape `(width, height, num_outputs)`, where each output is a prediction.\n",
    "We're predicting for 2 and 6 hours into the future, so we get 2 outputs."
   ],
   "metadata": {
    "id": "JWFJY1pv7T91"
   },
   "id": "JWFJY1pv7T91"
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oQnLpK0OmutA"
   },
   "source": [
    "## 📓 Define the model\n",
    "\n",
    "First we define our model, which is a very simple _Fully Convolutional Network_.\n",
    "The input data can contain very large numbers, but machine learning models generally train better on small values, roughly between -1 and 1.\n",
    "So in [`weather/model.py`](../serving/weather-model/weather/model.py) we defined a `Normalization` layer which applies [Z-Score](https://developers.google.com/machine-learning/data-prep/transform/normalization#z-score) to normalize all the model's inputs as a first step.\n",
    "But we need to provide it with the [_mean_](https://en.wikipedia.org/wiki/Mean) and [_standard deviation_](https://en.wikipedia.org/wiki/Standard_deviation) from the training dataset.\n",
    "\n",
    "A model always processes _batches_ of inputs, so we always get an extra _first_ dimension.\n",
    "This means that for all the layers in the model, our inputs have the shape `(batch, width, height, num_inputs)`, and our outputs have the shape `(batch, width, height, num_outputs)`.\n",
    "\n",
    "We need to calculate the mean and standard deviation for each input, so each band is normalized within its own range.\n",
    "Both the mean and standard deviation must have the shape `(1, 1, 1, num_inputs)`, which allows them to _broadcast_ to any batch size, width and height, as long as `num_inputs` matches."
   ],
   "id": "oQnLpK0OmutA"
  },
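  {
   "cell_type": "markdown",
   "metadata": {
    "id": "normalizeSketch01"
   },
   "source": [
    "To make the broadcasting concrete, here is a minimal NumPy sketch of Z-score normalization on random stand-in data (an assumption; it's not our dataset).\n",
    "It computes one mean and one standard deviation per band over the batch, width, and height axes, keeping them as `(1, 1, 1, num_inputs)` so they broadcast against batches of any size.\n",
    "The `epsilon` term is also an assumption, added to avoid dividing by zero for a constant band; the `Normalization` layer may handle that differently.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Stand-in training inputs: (batch, width, height, num_inputs), values around 2000.\n",
    "train_inputs = rng.normal(2000.0, 300.0, size=(100, 5, 5, 52)).astype(np.float32)\n",
    "\n",
    "# One statistic per band; keepdims=True keeps the shape as (1, 1, 1, 52).\n",
    "mean = train_inputs.mean(axis=(0, 1, 2), keepdims=True)\n",
    "std = train_inputs.std(axis=(0, 1, 2), keepdims=True)\n",
    "epsilon = 1e-6  # assumption: guards against bands with zero variance\n",
    "\n",
    "# The same statistics broadcast to a different batch size, width, and height.\n",
    "new_batch = rng.normal(2000.0, 300.0, size=(3, 7, 7, 52)).astype(np.float32)\n",
    "normalized = (new_batch - mean) / (std + epsilon)\n",
    "print(mean.shape, normalized.shape)\n",
    "```\n",
    "\n",
    "This prints `(1, 1, 1, 52) (3, 7, 7, 52)`. Using `keepdims=True` yields the same shape as indexing with `[None, None, None, :]`."
   ],
   "id": "normalizeSketch01"
  },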
  {
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "\n",
    "# Let's get the mean and standard deviation.\n",
    "data = np.array(dataset[\"train\"][\"inputs\"], np.float32)\n",
    "mean = data.mean(axis=(0, 1, 2))[None, None, None, :]\n",
    "std = data.std(axis=(0, 1, 2))[None, None, None, :]\n",
    "\n",
    "print(f\"mean: {mean.shape}\")\n",
    "print(f\"std:  {std.shape}\")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "YkOwsJBuYIHg",
    "outputId": "4a3ed264-5dca-4169-bee9-9e4ecaea5409"
   },
   "id": "YkOwsJBuYIHg",
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "mean: (1, 1, 1, 52)\n",
      "std:  (1, 1, 1, 52)\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "Let's see how the normalization works for a sample of an example's inputs."
   ],
   "metadata": {
    "id": "meHaHpxW-zt5"
   },
   "id": "meHaHpxW-zt5"
  },
  {
   "cell_type": "code",
   "source": [
    "import torch\n",
    "\n",
    "from weather.model import Normalization\n",
    "\n",
    "normalization = Normalization(mean, std)\n",
    "\n",
    "sample = lambda x: x[0, 0, 0, 10:15].detach().numpy()\n",
    "\n",
    "print(f\"mean: {sample(normalization.mean)}\")\n",
    "print(f\"std:  {sample(normalization.std)}\")\n",
    "print(\"-\" * 40)\n",
    "\n",
    "example = dataset[\"train\"][0]\n",
    "example_inputs = torch.as_tensor([example[\"inputs\"]])\n",
    "normalized_inputs = normalization(example_inputs)\n",
    "print(f\"inputs:     {sample(example_inputs)}\")\n",
    "print(f\"normalized: {sample(normalized_inputs)}\")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "EUT8fowo-_Bv",
    "outputId": "5fc2a64b-7dde-4803-9d21-34264bcf93f5"
   },
   "id": "EUT8fowo-_Bv",
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "mean: [2202.3132 2355.514  2328.052  2470.9158 2687.0806]\n",
      "std:  [256.82922 324.5936  332.1437  480.68338 351.21927]\n",
      "----------------------------------------\n",
      "inputs:     [2295. 2514. 2534. 2774. 2957.]\n",
      "normalized: [0.36088872 0.48826003 0.6200569  0.6305278  0.76852113]\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "After applying the `Normalization` layer, we get small numbers much closer to the -1 to 1 range. They don't have to be _exactly_ within that range, just close enough.\n",
    "\n",
    "Another thing to note is that our data is in a channels-last format, like `(width, height, channels)`.\n",
    "But PyTorch expects channels-first format in the convolutional layers, like `(channels, width, height)`.\n",
    "We still want to pass our inputs in a channels-last format and want the predictions back as channels-last for convenience, but we must convert them to channels-first for PyTorch convolutional layers to work.\n",
    "\n",
    "In [`weather/model.py`](../serving/weather-model/weather/model.py) we define the `MoveDim` layer, which works similarly to [`torch.movedim`](https://pytorch.org/docs/stable/generated/torch.movedim.html), so the model can move the channels dimension as needed.\n"
   ],
   "metadata": {
    "id": "Idvef7Id49vE"
   },
   "id": "Idvef7Id49vE"
  },
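  {
   "cell_type": "markdown",
   "metadata": {
    "id": "moveaxisSketch01"
   },
   "source": [
    "The same dimension move can be reproduced with NumPy's `np.moveaxis`, which is a handy way to double-check shapes without PyTorch.\n",
    "This sketch only illustrates the idea on a stand-in batch; it is not how `MoveDim` is implemented.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Stand-in batch in channels-last format: (batch, width, height, channels).\n",
    "batch = np.arange(1 * 5 * 5 * 52, dtype=np.float32).reshape(1, 5, 5, 52)\n",
    "\n",
    "# Move the channels axis (-1) to index 1, right after the batch axis.\n",
    "channels_first = np.moveaxis(batch, -1, 1)\n",
    "\n",
    "# Moving it back from index 1 to the end restores the original layout.\n",
    "channels_last = np.moveaxis(channels_first, 1, -1)\n",
    "\n",
    "print(channels_first.shape, channels_last.shape)\n",
    "```\n",
    "\n",
    "This prints `(1, 52, 5, 5) (1, 5, 5, 52)`, the same shapes as the `MoveDim` round trip."
   ],
   "id": "moveaxisSketch01"
  },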
  {
   "cell_type": "code",
   "source": [
    "from weather.model import MoveDim\n",
    "\n",
    "# We move the channels/last dimension (-1) to the second index (1),\n",
    "# since the first (0) is for the batch dimension.\n",
    "to_channels_first = MoveDim(-1, 1)\n",
    "channels_first = to_channels_first(normalized_inputs)\n",
    "\n",
    "print(f\"normalized:     {normalized_inputs.shape}\")\n",
    "print(f\"channels-first: {channels_first.shape}\")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "AkrmnehOCuol",
    "outputId": "0a41f8d0-8f61-4946-92e1-67e33e3eddf1"
   },
   "id": "AkrmnehOCuol",
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "normalized:     torch.Size([1, 5, 5, 52])\n",
      "channels-first: torch.Size([1, 52, 5, 5])\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "The model then passes the data through a\n",
    "[2D Convolutional layer](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) for downsampling, and then through a\n",
    "[2D DeConvolutional layer](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html) for upsampling, so we end up with images the same size as the input image.\n",
    "We use a [`ReLU`](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) activation function in between all hidden layers since it's typically a good general-purpose activation function.\n",
    "\n",
    "The Conv2D and DeConv2D layers form a very simple Fully Convolutional Network architecture. Since both use the same _kernel size_ (with a stride of 1 and no padding), the outputs have the same `(width, height)` as the inputs."
   ],
   "metadata": {
    "id": "6JpbxntkEEtv"
   },
   "id": "6JpbxntkEEtv"
  },
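  {
   "cell_type": "markdown",
   "metadata": {
    "id": "convShapeSketch01"
   },
   "source": [
    "We can check these shapes by hand with the output size formulas from the PyTorch documentation for [`Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) and [`ConvTranspose2d`](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html).\n",
    "The helper functions below are hypothetical and only cover one spatial dimension.\n",
    "With PyTorch's defaults (stride 1, no padding, dilation 1, no output padding), the formulas reduce to `in - kernel + 1` and `in + kernel - 1`, so a 5-pixel side shrinks to 3 and grows back to 5.\n",
    "\n",
    "```python\n",
    "def conv2d_size(\n",
    "    size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1\n",
    ") -> int:\n",
    "    # Output size formula from the torch.nn.Conv2d documentation.\n",
    "    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1\n",
    "\n",
    "\n",
    "def conv_transpose2d_size(\n",
    "    size: int,\n",
    "    kernel: int,\n",
    "    stride: int = 1,\n",
    "    padding: int = 0,\n",
    "    dilation: int = 1,\n",
    "    output_padding: int = 0,\n",
    ") -> int:\n",
    "    # Output size formula from the torch.nn.ConvTranspose2d documentation.\n",
    "    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1\n",
    "\n",
    "\n",
    "width = 5\n",
    "kernel = 3\n",
    "downsampled = conv2d_size(width, kernel)  # 5 - 3 + 1 = 3\n",
    "upsampled = conv_transpose2d_size(downsampled, kernel)  # 3 + 3 - 1 = 5\n",
    "print(downsampled, upsampled)\n",
    "```\n",
    "\n",
    "This prints `3 5`, matching the `(5, 5)` spatial size of the FCN outputs."
   ],
   "id": "convShapeSketch01"
  },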
  {
   "cell_type": "code",
   "source": [
    "num_inputs = 52\n",
    "num_hidden1 = 64\n",
    "num_hidden2 = 128\n",
    "kernel_size = (3, 3)\n",
    "\n",
    "fully_convolutional_layers = torch.nn.Sequential(\n",
    "    torch.nn.Conv2d(num_inputs, num_hidden1, kernel_size),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.ConvTranspose2d(num_hidden1, num_hidden2, kernel_size),\n",
    "    torch.nn.ReLU(),\n",
    ")\n",
    "\n",
    "fcn_outputs = fully_convolutional_layers(channels_first)\n",
    "print(f\"FCN outputs: {fcn_outputs.shape}\")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3Ima73TIEG1z",
    "outputId": "4929396e-74c2-4cd3-9fdf-31e12117f064"
   },
   "id": "3Ima73TIEG1z",
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "FCN outputs: torch.Size([1, 128, 5, 5])\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "Now, let's convert the results back into channels-last format with `MoveDim`."
   ],
   "metadata": {
    "id": "TkRDEANqFoLd"
   },
   "id": "TkRDEANqFoLd"
  },
  {
   "cell_type": "code",
   "source": [
    "to_channels_last = MoveDim(1, -1)\n",
    "channels_last = to_channels_last(fcn_outputs)\n",
    "\n",
    "print(f\"channels-last: {channels_last.shape}\")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-oAMnWtfFuzr",
    "outputId": "7a436c24-4ab0-4040-dd6f-cbbf167d03ff"
   },
   "id": "-oAMnWtfFuzr",
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "channels-last: torch.Size([1, 5, 5, 128])\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "For the last layer, we use a [`Linear`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) layer with the number of outputs we want.\n",
    "Since we can't have negative precipitation, we pass the model's outputs through a final `ReLU` activation function."
   ],
   "metadata": {
    "id": "7x_OkkNvGabm"
   },
   "id": "7x_OkkNvGabm"
  },
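  {
   "cell_type": "markdown",
   "metadata": {
    "id": "linearSketch01"
   },
   "source": [
    "`torch.nn.Linear` operates on the last dimension only, so on a channels-last tensor it applies the same weights to every pixel, much like a 1x1 convolution.\n",
    "Here is a minimal NumPy sketch of that computation, assuming random stand-in weights rather than the model's trained ones.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "hidden = rng.normal(size=(1, 5, 5, 128)).astype(np.float32)  # channels-last FCN outputs\n",
    "\n",
    "# Same layout as torch.nn.Linear: weight is (out_features, in_features).\n",
    "weight = rng.normal(size=(2, 128)).astype(np.float32)\n",
    "bias = rng.normal(size=(2,)).astype(np.float32)\n",
    "\n",
    "# y = x @ W^T + b, applied independently at every (batch, width, height) position.\n",
    "raw_predictions = hidden @ weight.T + bias\n",
    "\n",
    "# Final ReLU: clamp negative precipitation to zero.\n",
    "predictions = np.maximum(raw_predictions, 0.0)\n",
    "print(predictions.shape, bool((predictions >= 0).all()))\n",
    "```\n",
    "\n",
    "This prints `(1, 5, 5, 2) True`: each pixel's 128 hidden features become 2 non-negative predictions."
   ],
   "id": "linearSketch01"
  },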
  {
   "cell_type": "code",
   "source": [
    "num_outputs = 2\n",
    "\n",
    "linear = torch.nn.Linear(num_hidden2, num_outputs)\n",
    "relu = torch.nn.ReLU()\n",
    "\n",
    "with torch.no_grad():\n",
    "    raw_predictions = linear(channels_last)\n",
    "    predictions = relu(raw_predictions)\n",
    "\n",
    "print(f\"predictions: {predictions.shape}\")\n",
    "print(predictions[0, 0, 0])"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "q0T8CpEaGJew",
    "outputId": "227fe92d-2f86-4160-aad9-971d08032a51"
   },
   "id": "q0T8CpEaGJew",
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "predictions: torch.Size([1, 5, 5, 2])\n",
      "tensor([0.0650, 0.0010])\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CDQIGsp24EX9"
   },
   "source": [
    "In [`weather/model.py`](../serving/weather-model/weather/model.py) we defined the `WeatherModel` and `WeatherConfig` classes.\n",
    "\n",
    "The `WeatherModel` class inherits from [`PreTrainedModel`](https://huggingface.co/docs/transformers/main/en/main_classes/model) to make it compatible with [🤗 Transformers](https://huggingface.co/docs/transformers/main/en/index).\n",
    "\n",
    "The model definition includes the loss function, which measures how good or bad its predictions are.\n",
    "We could use any regression loss function like [Mean Absolute Error (L1)](https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html) or [Mean Squared Error (L2)](https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html).\n",
    "PyTorch provides a [Smooth L1 Loss](https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html), which uses a squared (L2) term when the absolute error is below a threshold `beta`, and an L1 term otherwise.\n",
    "It's less sensitive to outliers than L2, so we'll use that.\n",
    "\n",
    "To create a `WeatherModel`, we have to pass it a `WeatherConfig`.\n",
    "The `WeatherConfig` contains all the model's hyperparameters, and we must also pass the _mean_ and _standard deviation_ from the training dataset for the normalization layer.\n",
    "We defined `WeatherModel.create`, which takes the training dataset inputs and returns a `WeatherModel` with the right `WeatherConfig`."
   ],
   "id": "CDQIGsp24EX9"
  },
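  {
   "cell_type": "markdown",
   "metadata": {
    "id": "smoothL1Sketch01"
   },
   "source": [
    "Here's a minimal NumPy sketch of the Smooth L1 formula as documented for `torch.nn.SmoothL1Loss`, assuming the default `beta=1.0` and mean reduction.\n",
    "Errors smaller than `beta` are penalized quadratically and larger errors only linearly, so a single outlier doesn't dominate the loss the way it would with L2.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def smooth_l1_loss(predictions: np.ndarray, labels: np.ndarray, beta: float = 1.0) -> float:\n",
    "    # Element-wise absolute error.\n",
    "    error = np.abs(predictions - labels)\n",
    "    # Quadratic below beta, linear above it (both pieces meet at error == beta).\n",
    "    loss = np.where(error < beta, 0.5 * error**2 / beta, error - 0.5 * beta)\n",
    "    # Mean over all elements, like reduction='mean'.\n",
    "    return float(loss.mean())\n",
    "\n",
    "\n",
    "labels = np.array([0.0, 0.0, 0.0, 0.0])\n",
    "predictions = np.array([0.5, 0.0, 0.0, 3.0])  # one small error, one outlier\n",
    "print(smooth_l1_loss(predictions, labels))\n",
    "```\n",
    "\n",
    "This prints `0.65625`. For comparison, the mean squared error of the same predictions is `(0.25 + 9) / 4 = 2.3125`, dominated by the single outlier."
   ],
   "id": "smoothL1Sketch01"
  },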
  {
   "cell_type": "code",
   "source": [
    "from weather.model import WeatherModel\n",
    "\n",
    "model = WeatherModel.create(dataset[\"train\"][\"inputs\"])\n",
    "print(model)"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "h0bzkGqwo-Ic",
    "outputId": "4a0f8622-e3fd-49cf-9eb2-c9e2777174b0"
   },
   "id": "h0bzkGqwo-Ic",
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "WeatherModel(\n",
      "  (layers): Sequential(\n",
      "    (0): Normalization()\n",
      "    (1): MoveDim()\n",
      "    (2): Conv2d(52, 64, kernel_size=(3, 3), stride=(1, 1))\n",
      "    (3): ReLU()\n",
      "    (4): ConvTranspose2d(64, 128, kernel_size=(3, 3), stride=(1, 1))\n",
      "    (5): ReLU()\n",
      "    (6): MoveDim()\n",
      "    (7): Linear(in_features=128, out_features=2, bias=True)\n",
      "    (8): ReLU()\n",
      "  )\n",
      ")\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "The model returns a `{'loss': torch.Tensor, 'logits': torch.Tensor}` dictionary when we pass the labels, as during training, and a `{'logits': torch.Tensor}` dictionary when we don't, as during predictions.\n",
    "This is the format 🤗 Transformers expects for [model outputs](https://huggingface.co/docs/transformers/main/en/main_classes/output).\n",
    "\n",
    "Remember that we _must_ pass a _batch_ of inputs to the model, not a single input."
   ],
   "metadata": {
    "id": "6iS60sGCJczT"
   },
   "id": "6iS60sGCJczT"
  },
  {
   "cell_type": "code",
   "source": [
    "test_dataset = dataset[\"test\"]\n",
    "inputs_batch = torch.as_tensor(test_dataset[\"inputs\"][:1])\n",
    "labels_batch = torch.as_tensor(test_dataset[\"labels\"][:1])\n",
    "\n",
    "# We pass the labels as well to get the loss, but it's optional.\n",
    "# If we don't pass the labels, we simply won't get the loss.\n",
    "# The predictions are under the 'logits' key.\n",
    "with torch.no_grad():\n",
    "    predictions = model(inputs_batch, labels_batch)\n",
    "\n",
    "print(f\"inputs:      {inputs_batch.shape}\")\n",
    "print(f\"labels:      {labels_batch.shape}\")\n",
    "print(f\"loss:        {predictions['loss']}\")\n",
    "print(f\"predictions: {predictions['logits'].shape}\")\n",
    "print(\"-\" * 40)\n",
    "print(f\"sample labels:      {labels_batch[0, 0, 0]}\")\n",
    "print(f\"sample predictions: {predictions['logits'][0, 0, 0]}\")"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "gKwdukOzJeNA",
    "outputId": "63287630-c21b-4b21-c6b0-168423fd2746"
   },
   "id": "gKwdukOzJeNA",
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "inputs:      torch.Size([1, 5, 5, 52])\n",
      "labels:      torch.Size([1, 5, 5, 2])\n",
      "loss:        0.009296745993196964\n",
      "predictions: torch.Size([1, 5, 5, 2])\n",
      "----------------------------------------\n",
      "sample labels:      tensor([0., 0.])\n",
      "sample predictions: tensor([0.0797, 0.0000])\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "These predictions don't look great because we haven't trained our model.\n",
    "Fortunately, since we've made our model compatible with 🤗 Transformers, we can simply use [`Trainer`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer), which takes care of the training and evaluation loops for us, and uses accelerators like GPUs if available."
   ],
   "metadata": {
    "id": "cxyoRnNlzsYu"
   },
   "id": "cxyoRnNlzsYu"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 👟 Train the model\n",
    "\n",
    "We have to define the number of times we want the model to go through the training dataset. This is called the number of _epochs_.\n",
    "We also have to define the _batch size_ to use during training and testing. It can have a big impact on how fast the model trains; as a rule of thumb, larger batches train faster as long as they fit into memory.\n",
    "We define all these parameters with [`TrainingArguments`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments).\n",
    "\n",
    "Then we pass the model, the `TrainingArguments`, and the training and testing datasets into the `Trainer`.\n",
    "Finally we can train the model with [`Trainer.train`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.train)."
   ],
   "metadata": {
    "id": "xG6PnXhfLzxO"
   },
   "id": "xG6PnXhfLzxO"
  },
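  {
   "cell_type": "markdown",
   "metadata": {
    "id": "stepsSketch01"
   },
   "source": [
    "As a sanity check on these parameters, the total number of optimization steps is the number of batches per epoch times the number of epochs.\n",
    "This small calculation assumes a single device and no gradient accumulation.\n",
    "Since 3069 isn't a multiple of 512, the last batch of each epoch is smaller, so we round up.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "num_examples = 3069  # size of the train split printed earlier\n",
    "batch_size = 512\n",
    "epochs = 5\n",
    "\n",
    "# The last batch of each epoch holds the remaining 3069 - 5 * 512 = 509 examples.\n",
    "steps_per_epoch = math.ceil(num_examples / batch_size)\n",
    "total_steps = steps_per_epoch * epochs\n",
    "print(steps_per_epoch, total_steps)\n",
    "```\n",
    "\n",
    "This prints `6 30`, which matches `Total optimization steps = 30` in the training log."
   ],
   "id": "stepsSketch01"
  },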
  {
   "cell_type": "code",
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "epochs = 5\n",
    "batch_size = 512\n",
    "\n",
    "# Define our training job.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"checkpoints\",\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    num_train_epochs=epochs,\n",
    "    logging_strategy=\"epoch\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    ")\n",
    "trainer = Trainer(\n",
    "    model,\n",
    "    training_args,\n",
    "    train_dataset=dataset[\"train\"],\n",
    "    eval_dataset=dataset[\"test\"],\n",
    ")\n",
    "\n",
    "# Run the training job.\n",
    "trainer.train()"
   ],
   "metadata": {
    "id": "x4ta1oIsMveF",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 825
    },
    "outputId": "3a4e8674-dfa2-4cbf-8445-c3bcccfe4769"
   },
   "id": "x4ta1oIsMveF",
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "PyTorch: setting up devices\n",
      "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n",
      "***** Running training *****\n",
      "  Num examples = 3069\n",
      "  Num Epochs = 5\n",
      "  Instantaneous batch size per device = 512\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 512\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 30\n",
      "  Number of trainable parameters = 104234\n",
      "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ],
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='30' max='30' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [30/30 00:23, Epoch 5/5]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.288900</td>\n",
       "      <td>1.016647</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.279300</td>\n",
       "      <td>1.009680</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.271700</td>\n",
       "      <td>1.004657</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.266700</td>\n",
       "      <td>1.001499</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.263600</td>\n",
       "      <td>1.000306</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ]
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "***** Running Evaluation *****\n",
      "  Num examples = 341\n",
      "  Batch size = 512\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 341\n",
      "  Batch size = 512\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 341\n",
      "  Batch size = 512\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 341\n",
      "  Batch size = 512\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 341\n",
      "  Batch size = 512\n",
      "\n",
      "\n",
      "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
      "\n",
      "\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "TrainOutput(global_step=30, training_loss=1.2740394274393718, metrics={'train_runtime': 23.7216, 'train_samples_per_second': 646.878, 'train_steps_per_second': 1.265, 'total_flos': 0.0, 'train_loss': 1.2740394274393718, 'epoch': 5.0})"
      ]
     },
     "metadata": {},
     "execution_count": 31
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "> 💡 Both losses should generally go down every epoch, and they should be roughly similar.\n",
    "> If the training loss goes down but the validation (testing) loss stays flat or goes up, the model might be [overfitting](https://developers.google.com/machine-learning/crash-course/generalization/peril-of-overfitting), meaning that it's memorizing the training dataset instead of learning to generalize."
   ],
   "metadata": {
    "id": "jPFCmhruOvjB"
   },
   "id": "jPFCmhruOvjB"
  },
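  {
   "cell_type": "markdown",
   "metadata": {
    "id": "overfitChkB2"
   },
   "source": [
    "To check this programmatically, we can read the per-epoch losses from `trainer.state.log_history`, where training log entries include a `loss` key and evaluation entries include an `eval_loss` key.\n",
    "The overfitting check below is a simple hypothetical heuristic (the training loss is still improving while the validation loss hasn't beaten its earlier best for `patience` epochs), not a feature of 🤗 Transformers.\n",
    "\n",
    "```python\n",
    "def split_losses(log_history):\n",
    "    # Training log entries have a 'loss' key; evaluation entries have 'eval_loss'.\n",
    "    train = [entry['loss'] for entry in log_history if 'loss' in entry]\n",
    "    valid = [entry['eval_loss'] for entry in log_history if 'eval_loss' in entry]\n",
    "    return train, valid\n",
    "\n",
    "\n",
    "def looks_like_overfitting(train_losses, valid_losses, patience=2):\n",
    "    # Hypothetical heuristic: needs more epochs than the patience window.\n",
    "    if len(train_losses) <= patience or len(valid_losses) <= patience:\n",
    "        return False\n",
    "    # Did the training loss improve over the last `patience` epochs?\n",
    "    train_improving = train_losses[-1] < train_losses[-patience - 1]\n",
    "    # Best validation loss before the window, and the best one inside it.\n",
    "    best_before = min(valid_losses[:-patience])\n",
    "    best_recent = min(valid_losses[-patience:])\n",
    "    return train_improving and best_recent >= best_before\n",
    "\n",
    "\n",
    "# Per-epoch losses from the training table above.\n",
    "train_losses = [1.2889, 1.2793, 1.2717, 1.2667, 1.2636]\n",
    "valid_losses = [1.016647, 1.009680, 1.004657, 1.001499, 1.000306]\n",
    "print(looks_like_overfitting(train_losses, valid_losses))  # False\n",
    "```\n",
    "\n",
    "In the notebook, the same lists could come from `split_losses(trainer.state.log_history)` after `trainer.train()` finishes."
   ],
   "id": "overfitChkB2"
  },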
  {
   "cell_type": "markdown",
   "source": [
    "## 💾 Save and load the model\n",
    "\n",
    "After the model has finished training, we can save it with [`Trainer.save_model`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.save_model).\n",
    "\n"
   ],
   "metadata": {
    "id": "_AxB_p2-z4UH"
   },
   "id": "_AxB_p2-z4UH"
  },
  {
   "cell_type": "code",
   "source": [
    "trainer.save_model(\"model\")\n",
    "\n",
    "!ls -lh model"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NPLnvRydOik0",
    "outputId": "c788cf33-ec67-4612-9f44-1262dc872625"
   },
   "id": "NPLnvRydOik0",
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Saving model checkpoint to model\n",
      "Configuration saved in model/config.json\n",
      "Model weights saved in model/pytorch_model.bin\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "total 420K\n",
      "-rw-r--r-- 1 root root 3.4K Jan 11 21:33 config.json\n",
      "-rw-r--r-- 1 root root 410K Jan 11 21:33 pytorch_model.bin\n",
      "-rw-r--r-- 1 root root 3.4K Jan 11 21:33 training_args.bin\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "icxhkboQA_o5"
   },
   "source": [
    "Now that the trained model is saved, we can load it anywhere else.\n",
    "We can load a 🤗 Transformers model with [`PreTrainedModel.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained), in our case with `WeatherModel.from_pretrained`.\n",
    "This loads the model's weights and hyperparameters, including the _mean_ and _standard deviation_ for the normalization layer, which are stored in `config.json`."
   ],
   "id": "icxhkboQA_o5"
  },
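  {
   "cell_type": "markdown",
   "metadata": {
    "id": "normStatsC3"
   },
   "source": [
    "The saved `config.json` stores `mean` and `std` as nested lists with shape `(1, 1, 1, 52)`, one value per input feature, which suggests they broadcast over channels-last inputs.\n",
    "This minimal sketch reads those statistics and standardizes a single feature vector.\n",
    "It assumes the `Normalization` layer computes `(x - mean) / std`, so check the `weather.model` module for the actual implementation; both helpers are hypothetical.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def load_normalization_stats(config_path):\n",
    "    # `mean` and `std` are saved with shape (1, 1, 1, num_inputs).\n",
    "    with open(config_path) as f:\n",
    "        config = json.load(f)\n",
    "    return config['mean'][0][0][0], config['std'][0][0][0]\n",
    "\n",
    "\n",
    "def normalize(features, mean, std):\n",
    "    # Assumed z-score normalization: subtract each feature's mean, divide by its std.\n",
    "    return [(x - m) / s for x, m, s in zip(features, mean, std)]\n",
    "\n",
    "\n",
    "print(normalize([3.0, 10.0], [1.0, 10.0], [2.0, 5.0]))  # [1.0, 0.0]\n",
    "```\n",
    "\n",
    "For example, `load_normalization_stats('model/config.json')` should return two lists of 52 values each, matching `num_inputs`."
   ],
   "id": "normStatsC3"
  },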
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xsxX2Mb-CwWk",
    "outputId": "4bc42359-88c8-4803-be4c-34cd9e07a8e1"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "loading configuration file model/config.json\n",
      "Model config WeatherConfig {\n",
      "  \"architectures\": [\n",
      "    \"WeatherModel\"\n",
      "  ],\n",
      "  \"kernel_size\": [\n",
      "    3,\n",
      "    3\n",
      "  ],\n",
      "  \"mean\": [\n",
      "    [\n",
      "      [\n",
      "        [\n",
      "          0.965579092502594,\n",
      "          2.3415911197662354,\n",
      "          6.150100231170654,\n",
      "          476.72564697265625,\n",
      "          421.8377380371094,\n",
      "          521.5245971679688,\n",
      "          109.100830078125,\n",
      "          300.76141357421875,\n",
      "          262.6136474609375,\n",
      "          5461.68310546875,\n",
      "          2202.313232421875,\n",
      "          2355.513916015625,\n",
      "          2328.052001953125,\n",
      "          2470.915771484375,\n",
      "          2687.08056640625,\n",
      "          2737.617919921875,\n",
      "          2684.49365234375,\n",
      "          2650.0927734375,\n",
      "          2816.9892578125,\n",
      "          509.75927734375,\n",
      "          451.73077392578125,\n",
      "          535.8512573242188,\n",
      "          140.81637573242188,\n",
      "          276.422607421875,\n",
      "          257.4959411621094,\n",
      "          4964.77197265625,\n",
      "          2143.988037109375,\n",
      "          2276.671630859375,\n",
      "          2243.602783203125,\n",
      "          2340.478759765625,\n",
      "          2601.414794921875,\n",
      "          2623.432373046875,\n",
      "          2567.951904296875,\n",
      "          2536.750732421875,\n",
      "          2718.31591796875,\n",
      "          601.8285522460938,\n",
      "          540.8607788085938,\n",
      "          601.079345703125,\n",
      "          250.1461639404297,\n",
      "          271.73126220703125,\n",
      "          291.7319641113281,\n",
      "          4314.62744140625,\n",
      "          2050.633544921875,\n",
      "          2152.40283203125,\n",
      "          2113.24267578125,\n",
      "          2147.13232421875,\n",
      "          2477.33935546875,\n",
      "          2455.325927734375,\n",
      "          2397.76416015625,\n",
      "          2371.694091796875,\n",
      "          2573.267578125,\n",
      "          1774.619873046875\n",
      "        ]\n",
      "      ]\n",
      "    ]\n",
      "  ],\n",
      "  \"model_type\": \"weather\",\n",
      "  \"num_hidden1\": 64,\n",
      "  \"num_hidden2\": 128,\n",
      "  \"num_inputs\": 52,\n",
      "  \"num_outputs\": 2,\n",
      "  \"std\": [\n",
      "    [\n",
      "      [\n",
      "        [\n",
      "          3.4099764823913574,\n",
      "          5.503620147705078,\n",
      "          8.816463470458984,\n",
      "          652.4397583007812,\n",
      "          596.1942138671875,\n",
      "          676.47119140625,\n",
      "          253.98402404785156,\n",
      "          383.220947265625,\n",
      "          336.3753967285156,\n",
      "          1895.1807861328125,\n",
      "          256.8292236328125,\n",
      "          324.5935974121094,\n",
      "          332.1437072753906,\n",
      "          480.6833801269531,\n",
      "          351.2192687988281,\n",
      "          423.1502685546875,\n",
      "          439.25201416015625,\n",
      "          433.09442138671875,\n",
      "          380.0411376953125,\n",
      "          719.5698852539062,\n",
      "          652.4895629882812,\n",
      "          723.5523071289062,\n",
      "          309.5145568847656,\n",
      "          352.95697021484375,\n",
      "          330.34222412109375,\n",
      "          2052.62060546875,\n",
      "          294.1995849609375,\n",
      "          371.49407958984375,\n",
      "          375.6823425292969,\n",
      "          520.1500244140625,\n",
      "          360.0391540527344,\n",
      "          455.3539123535156,\n",
      "          472.1324157714844,\n",
      "          466.9129943847656,\n",
      "          412.6302185058594,\n",
      "          861.4622802734375,\n",
      "          791.761474609375,\n",
      "          841.4432983398438,\n",
      "          514.6561279296875,\n",
      "          439.2526550292969,\n",
      "          461.8729553222656,\n",
      "          2031.054931640625,\n",
      "          326.0904541015625,\n",
      "          409.5498962402344,\n",
      "          408.24627685546875,\n",
      "          533.72119140625,\n",
      "          348.2250671386719,\n",
      "          464.1885070800781,\n",
      "          482.2993469238281,\n",
      "          480.4700622558594,\n",
      "          428.2205505371094,\n",
      "          1641.6630859375\n",
      "        ]\n",
      "      ]\n",
      "    ]\n",
      "  ],\n",
      "  \"torch_dtype\": \"float32\",\n",
      "  \"transformers_version\": \"4.25.1\"\n",
      "}\n",
      "\n",
      "loading weights file model/pytorch_model.bin\n",
      "All model checkpoint weights were used when initializing WeatherModel.\n",
      "\n",
      "All the weights of WeatherModel were initialized from the model checkpoint at model.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use WeatherModel for predictions without further training.\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "WeatherModel(\n",
      "  (layers): Sequential(\n",
      "    (0): Normalization()\n",
      "    (1): MoveDim()\n",
      "    (2): Conv2d(52, 64, kernel_size=(3, 3), stride=(1, 1))\n",
      "    (3): ReLU()\n",
      "    (4): ConvTranspose2d(64, 128, kernel_size=(3, 3), stride=(1, 1))\n",
      "    (5): ReLU()\n",
      "    (6): MoveDim()\n",
      "    (7): Linear(in_features=128, out_features=2, bias=True)\n",
      "    (8): ReLU()\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "from weather.model import WeatherModel\n",
    "\n",
    "model = WeatherModel.from_pretrained(\"model\")\n",
    "print(model)"
   ],
   "id": "xsxX2Mb-CwWk"
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IO73AYtsCIQ_"
   },
   "source": [
    "# ☁️ Train the model in Vertex AI\n",
    "\n",
    "> ⚠️ Training in Vertex AI doesn't currently work due to an underlying issue when using the Hugging Face `Trainer` API in Vertex AI. For more information, see [#9272](https://github.com/GoogleCloudPlatform/python-docs-samples/issues/9272).\n",
    "\n",
    "For this example, we trained on a very small dataset for only a few epochs.\n",
    "This means we don't have a representative number of examples and the model hasn't seen the data enough times, so it won't perform very well.\n",
    "\n",
    "Training on larger datasets for more epochs can take a lot of time, so it might be a good idea to do the training in the cloud.\n",
    "[Vertex AI](https://cloud.google.com/vertex-ai) is a good option, and it even lets us use hardware accelerators like GPUs.\n",
    "There are [PyTorch pre-built containers](https://cloud.google.com/vertex-ai/docs/training/pre-built-containers#pytorch) which include PyTorch and many common libraries, so we don't need to build a custom container.\n",
    "\n",
    "The model and trainer are defined in the [`serving/weather-model`](../serving/weather-model) package.\n",
    "To run it in Vertex AI, we must build the package, copy it to Cloud Storage, and launch a custom training job with [`CustomPythonPackageTrainingJob`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.CustomPythonPackageTrainingJob)."
   ],
   "id": "IO73AYtsCIQ_"
  },
  {
   "cell_type": "code",
   "source": [
    "# Build the `weather-model` package.\n",
    "!python -m build serving/weather-model"
   ],
   "metadata": {
    "id": "v1SZt1iA2Wrh"
   },
   "execution_count": null,
   "outputs": [],
   "id": "v1SZt1iA2Wrh"
  },
  {
   "cell_type": "code",
   "source": [
    "!ls -lh serving/weather-model/dist"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "2f2b4dc2-a287-4822-caed-7f8115246d7d",
    "id": "y4F1_eA32Wrh"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "total 16K\n",
      "-rw-r--r-- 1 root root 5.9K Jan 11 18:29 weather_model-1.0.0-py3-none-any.whl\n",
      "-rw-r--r-- 1 root root 4.3K Jan 11 18:29 weather-model-1.0.0.tar.gz\n"
     ]
    }
   ],
   "id": "y4F1_eA32Wrh"
  },
  {
   "cell_type": "code",
   "source": [
    "# Stage the `weather-model` package in Cloud Storage.\n",
    "!gsutil cp serving/weather-model/dist/weather-model-1.0.0.tar.gz gs://{bucket}/weather/"
   ],
   "metadata": {
    "id": "JA1k9ky02dsx"
   },
   "id": "JA1k9ky02dsx",
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "In Vertex AI, we can access Cloud Storage files directly as if they were local files via Cloud Storage FUSE.\n",
    "Cloud Storage files are available under `/gcs` followed by your bucket and file path.\n",
    "To learn more, see the [Cloud Storage as a File System in AI Training](https://cloud.google.com/blog/products/ai-machine-learning/cloud-storage-file-system-ai-training) blog post."
   ],
   "metadata": {
    "id": "yk9X4YQcDPpR"
   },
   "id": "yk9X4YQcDPpR"
  },
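  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fusePathD4"
   },
   "source": [
    "For example, the training job below reads its data from `/gcs/{bucket}/weather/data`, which is the FUSE path for `gs://{bucket}/weather/data`.\n",
    "This hypothetical helper sketches that mapping, assuming the `/gcs/` mount point described above.\n",
    "\n",
    "```python\n",
    "def gcs_uri_to_fuse_path(uri):\n",
    "    # Hypothetical helper: maps gs://BUCKET/PATH to /gcs/BUCKET/PATH.\n",
    "    prefix = 'gs://'\n",
    "    if not uri.startswith(prefix):\n",
    "        raise ValueError(f'Expected a gs:// URI, got {uri!r}')\n",
    "    return '/gcs/' + uri[len(prefix):]\n",
    "\n",
    "\n",
    "print(gcs_uri_to_fuse_path('gs://my-bucket/weather/data'))  # /gcs/my-bucket/weather/data\n",
    "```"
   ],
   "id": "fusePathD4"
  },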
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ny4x99GiS2Lm"
   },
   "outputs": [],
   "source": [
    "from google.cloud import aiplatform\n",
    "\n",
    "epochs = 100\n",
    "timeout_min = 60  # 1 hour\n",
    "\n",
    "# Cloud Storage paths.\n",
    "data_path = f\"/gcs/{bucket}/weather/data\"\n",
    "model_path = f\"/gcs/{bucket}/weather/model\"\n",
    "\n",
    "aiplatform.init(project=project, location=location, staging_bucket=f\"gs://{bucket}\")\n",
    "\n",
    "# Launch the custom training job.\n",
    "job = aiplatform.CustomPythonPackageTrainingJob(\n",
    "    display_name=\"weather-forecasting\",\n",
    "    python_package_gcs_uri=f\"gs://{bucket}/weather/weather-model-1.0.0.tar.gz\",\n",
    "    python_module_name=\"weather.trainer\",\n",
    "    container_uri=\"us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.2-8.py310:latest\",\n",
    ")\n",
    "job.run(\n",
    "    machine_type=\"n1-highmem-8\",\n",
    "    accelerator_type=\"NVIDIA_TESLA_T4\",\n",
    "    accelerator_count=1,\n",
    "    args=[\n",
    "        f\"--data-path={data_path}\",\n",
    "        f\"--model-path={model_path}\",\n",
    "        f\"--epochs={epochs}\",\n",
    "    ],\n",
    "    timeout=timeout_min * 60,  # in seconds\n",
    ")"
   ],
   "id": "Ny4x99GiS2Lm"
  },
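  {
   "cell_type": "markdown",
   "metadata": {
    "id": "trainArgsE5"
   },
   "source": [
    "The job passes `--data-path`, `--model-path`, and `--epochs` to the `weather.trainer` module.\n",
    "As a rough sketch of how an entry point could receive those flags, assuming it uses `argparse` (the actual `weather.trainer` module may parse them differently):\n",
    "\n",
    "```python\n",
    "import argparse\n",
    "\n",
    "\n",
    "def parse_args(argv=None):\n",
    "    # Hypothetical parser for the flags passed by the Vertex AI job above.\n",
    "    parser = argparse.ArgumentParser(description='Train the weather model.')\n",
    "    parser.add_argument('--data-path', required=True, help='Directory with the training data.')\n",
    "    parser.add_argument('--model-path', required=True, help='Directory to save the trained model to.')\n",
    "    parser.add_argument('--epochs', type=int, required=True, help='Number of training epochs.')\n",
    "    # argparse turns dashes into underscores: --data-path becomes args.data_path.\n",
    "    return parser.parse_args(argv)\n",
    "\n",
    "\n",
    "args = parse_args(['--data-path=/gcs/my-bucket/weather/data', '--model-path=/gcs/my-bucket/weather/model', '--epochs=100'])\n",
    "print(args.data_path, args.epochs)  # /gcs/my-bucket/weather/data 100\n",
    "```\n",
    "\n",
    "Because both paths use the `/gcs/` prefix, the trainer can read the data and write the model with regular file operations."
   ],
   "id": "trainArgsE5"
  },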
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zw_kcyw4gOLF"
   },
   "source": [
    "> 💡 Look at your Vertex AI training jobs: https://console.cloud.google.com/vertex-ai/training/custom-jobs"
   ],
   "id": "zw_kcyw4gOLF"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# ⛳️ What's next?\n",
    "\n",
    "* [![Open in Colab](https://github.com/googlecolab/open_in_colab/raw/main/images/icon16.png) **🔮 Model predictions**](https://colab.research.google.com/github/GoogleCloudPlatform/python-docs-samples/blob/main/people-and-planet-ai/weather-forecasting/notebooks/4-predictions.ipynb):\n",
    "  Get predictions from the model with data it has never seen before."
   ],
   "metadata": {
    "id": "79RnF-lYBRTS"
   },
   "id": "79RnF-lYBRTS"
  }
 ],
 "metadata": {
  "colab": {
   "provenance": [],
   "toc_visible": true
  },
  "environment": {
   "kernel": "python3",
   "name": "tf2-gpu.2-6.m82",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-6:m82"
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
